# Kernel: sgemm_tn_128x128

# Copyright 2014 Nervana Systems Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#    http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Symbolic names for the immediates and constant-bank slots used by this kernel.
#   addr_zero        : address used as [RZ + addr_zero]; the prologue stores RZ
#                      there and reads it back to clear registers.
#   gridDimA/B       : c[0x0][0x14] / c[0x0][0x18].
#   param_*          : contiguous 4-byte slots from c[0x0][0x140] to
#                      c[0x0][0x190].
#   param_X[0]/[1]   : the two 32-bit halves of a pointer argument. For A, B and
#                      Rand the code adds [0] with .CC and [1] with carry, so
#                      [0] is the low word.
<CONSTANT_MAPPING>
    addr_zero  : 4x<128*8*4>

    gridDimA : c[0x0][0x14]
    gridDimB : c[0x0][0x18]

    param_Rand[0]   : c[0x0][0x140]
    param_Rand[1]   : c[0x0][0x144]
    param_A[0]      : c[0x0][0x148]
    param_A[1]      : c[0x0][0x14c]
    param_B[0]      : c[0x0][0x150]
    param_B[1]      : c[0x0][0x154]
    param_C[0]      : c[0x0][0x158]
    param_C[1]      : c[0x0][0x15c]
    param_lda8      : c[0x0][0x160]
    param_ldb8      : c[0x0][0x164]
    param_ldc       : c[0x0][0x168]
    param_m         : c[0x0][0x16c]
    param_n         : c[0x0][0x170]
    param_k         : c[0x0][0x174]
    param_alpha     : c[0x0][0x178]
    param_beta      : c[0x0][0x17c]
    param_flags     : c[0x0][0x180]
    param_ldaz      : c[0x0][0x184]
    param_ldbz      : c[0x0][0x188]
    param_ldcz      : c[0x0][0x18c]
    param_loops     : c[0x0][0x190]
</CONSTANT_MAPPING>

# Register allocation for the whole kernel.
# Several physical ranges are mapped more than once (0-63 for czero and for
# cx<..>y<..>; 64-95 for the prologue temporaries, for x/y and for the j0/j1
# operand registers; 64-121 again for the names listed at the bottom), so names
# that share a range must not be live at the same time.
# NOTE(review): by maxas convention ':' binds names to the listed registers in
# order and '~' lets the assembler pick the assignment inside the range --
# confirm against the assembler actually used to build this file.
# NOTE(review): lda8 is mapped below but is not referenced in this file before
# the INCLUDE (param_lda8 is read directly); it may be used by the included file.
<REGISTER_MAPPING>

    64-95   ~ blkA, blkB, blkZ, lda, ldb, ldaz, ldbz, tid1, tid7, tidX, blk, tid31, tid128, dimA, flag, tbid, param_A_base_lo, param_A_base_hi, ta_shl, param_B_base_lo, param_B_base_hi,tb_shl, tc_shl, k_and

    # Registers cleared by the prologue (LDS.U.128 into czero00, czero04, ...).
    0-63    : czero<00-63>

    # The same 64 registers under the names cx<0-7>y<0-7>; the explicit
    # register numbers per row are fixed by hand and must not be reordered.
     1, 4,17,20,33,36,49,52 : cx<0-7>y0
     5, 0,21,16,37,32,53,48 : cx<0-7>y1
     3, 6,19,22,35,38,51,54 : cx<0-7>y2
     7, 2,23,18,39,34,55,50 : cx<0-7>y3
     9,12,25,28,41,44,57,60 : cx<0-7>y4
    13, 8,29,24,45,40,61,56 : cx<0-7>y5
    11,14,27,30,43,46,59,62 : cx<0-7>y6
    15,10,31,26,47,42,63,58 : cx<0-7>y7

    # Scratch used only by the scalar ($vec == 0) load path.
    64-95   ~ x<1-3>, y<1-3>

    # Two sets (j0, j1) of A/B operand registers.
    64-67   : j0Ay<0-3>
    72-75   : j0Ay<4-7>
    68-71   : j0Bx<0-3>
    76-79   : j0Bx<4-7>
    80-83   : j1Ay<0-3>
    88-91   : j1Ay<4-7>
    84-87   : j1Bx<0-3>
    92-95   : j1Bx<4-7>

    # Global-load staging registers (targets of the 128-bit loads).
    96-103  : loadA<0-3>, loadB<0-3>

    # 64-bit global pointers: <0> is the low word, <1> the high word.
    104-107 : trackA<0-1>, trackB<0-1>

    108-121 ~ writeS, lda8, k, tidY, txa, txb, ta, tb, loop
    #108-116 ~ writeS, lda8, k, tidY, txa, txb, ta, tb, loop
    #122-125 ~ readAs, readBs, tid, seed
    125     ~ readAs
    124     ~ readBs
    123     ~ tid
    122     ~ seed
    126-127 : Rand<0-1>

    # Names below are not used in this file before the INCLUDE; presumably they
    # belong to the output stage in the included common file.
    64-75   ~ ldc, ci, xmad_c, clk_shf1, clk_shf2, tid_31, tid_96, tid_128, mantissa_bits, blockA, blockB, blockZ
    64-75   : c<0-7>, d3, d2, d1, d0
    76-85   : C00y<0-1>, C04y<0-1>, C08y<0-1>, C12y<0-1>
    86-121  ~ ldc1, ldc4, ldc60, ldcz, writeCs, readCs, cx<00|64>, cy<00|04|08|12>, alpha, beta, flags, exp_scale, trunc_mask, lfsr<0-2>, exp<0-3>, rand<0-3>, lfsr<0-2>_1, lfsr<0-2>_2
</REGISTER_MAPPING>

// Read the thread index and the three block indices.
// Axis mapping: blkA <- CTAID.Y, blkB <- CTAID.Z, blkZ <- CTAID.X.
-:-:-:-:00 S2R tid, SR_TID.X;
-:-:-:-:00 S2R blkA, SR_CTAID.Y;
-:-:-:-:00 S2R blkB, SR_CTAID.Z;
-:-:-:-:00 S2R blkZ, SR_CTAID.X;

//<SCHEDULE_BLOCK>
// k counts the remaining depth; loop starts at zero.
-:-:-:-:00 MOV k, param_k;
-:-:-:-:00 MOV loop, RZ;
// Store RZ at addr_zero, then read it back with 16 generated LDS.U.128
// instructions targeting czero00, czero04, ..., czero60.
-:-:-:-:00 STS.128 [RZ + addr_zero], RZ;
<CODE>
join('', map sprintf("-:-:-:-:00 LDS.U.128 czero%02d, [RZ + addr_zero];\n", $_ * 4), 0..15);
</CODE>

// tidX = (tid & 31) << 2
// tidY = (tid >> 5) & 7
-:-:-:-:00 LOP.AND tid31, tid, 31;
-:-:-:-:00 SHL tidX, tid31, 2;
-:-:-:-:00 BFE.U32 tidY, tid, 0x305; // 3 bits at position 5

// lda = param_lda8 >> 5, ldb = param_ldb8 >> 5
// NOTE(review): presumably param_ld*8 is the byte stride of 8 rows of 4-byte
// elements (ld * 32), which would make lda/ldb element strides -- confirm
// against the host-side launcher.
-:-:-:-:00 MOV lda, param_lda8;
-:-:-:-:00 MOV ldb, param_ldb8;
-:-:-:-:00 SHR.U32 lda, lda, 5;
-:-:-:-:00 SHR.U32 ldb, ldb, 5;
-:-:-:-:00 MOV ldaz, param_ldaz;
-:-:-:-:00 MOV ldbz, param_ldbz;

// trackA = param_A + (blkA*128 + tidX + lda*tidY + ldaz*blkZ) * 4
// txa = blkA*128 + tidX
-:-:-:-:00 ISCADD txa, blkA, tidX, 7;
// ta = lda*tidY + txa
//-:-:-:-:00 XMAD.LO2 ta, lda, tidY, txa;
// ta = ldaz*blkZ + ta
//-:-:-:-:00 XMAD.LO2 ta, ldaz, blkZ, ta;
// IMAD is used in place of the commented-out XMAD.LO2 forms.
-:-:-:-:00 IMAD ta, lda, tidY, txa;
-:-:-:-:00 IMAD ta, ldaz, blkZ, ta;
//-:-:-:-:00 LEA trackA0.CC, ta, param_A[0], 0x2;
//-:-:-:-:00 LEA.HI.X trackA1, ta, param_A[1], RZ, 0x2;

// The commented-out LEA pair is expanded by hand into MOV + SHL + IADD/IADD.X.
-:-:-:-:00 MOV param_A_base_lo, param_A[0];
-:-:-:-:00 MOV param_A_base_hi, param_A[1];
// ta_shl = ta << 2 = ta * 4 (element index -> byte offset)
-:-:-:-:00 SHL ta_shl, ta, 2;
// trackA0 = ta_shl + param_A[0], trackA1 = param_A[1] + carry
// NOTE(review): the bits shifted out of ta by SHL are not added to the high
// word (IADD.X uses RZ), so only the low 32 bits of the byte offset are
// honoured -- presumably the LEA.HI.X form kept them; confirm that offsets
// never reach 4 GiB.
-:-:-:-:00 IADD trackA0.CC, ta_shl, param_A_base_lo;
-:-:-:-:00 IADD.X trackA1, RZ, param_A_base_hi;

// P5 = txa < m (column bound for this thread's A load)
-:-:-:-:00 ISETP.LT.AND P5, PT, txa, param_m, PT;

// trackB = param_B + (blkB*128 + tidX + ldb*tidY + ldbz*blkZ) * 4
// txb = blkB*128 + tidX
-:-:-:-:00 ISCADD txb, blkB, tidX, 7;
//-:-:-:-:00 XMAD.LO2 tb, ldb, tidY, txb;
//-:-:-:-:00 XMAD.LO2 tb, ldbz, blkZ, tb;
//-:-:-:-:00 LEA trackB0.CC, tb, param_B[0], 0x2;
//-:-:-:-:00 LEA.HI.X trackB1, tb, param_B[1], RZ, 0x2;
// tb = ldb*tidY + txb, then tb += ldbz*blkZ (IMAD replaces XMAD.LO2)
-:-:-:-:00 IMAD tb, ldb, tidY, txb;
-:-:-:-:00 IMAD tb, ldbz, blkZ, tb;

// Hand expansion of the commented-out LEA pair: byte offset = tb << 2.
// NOTE(review): as with trackA, bits shifted out of tb are not carried into
// trackB1 -- confirm that offsets never reach 4 GiB.
-:-:-:-:00 MOV param_B_base_lo, param_B[0];
-:-:-:-:00 MOV param_B_base_hi, param_B[1];
-:-:-:-:00 SHL tb_shl, tb, 2;
-:-:-:-:00 IADD trackB0.CC, tb_shl, param_B_base_lo;
-:-:-:-:00 IADD.X trackB1, RZ, param_B_base_hi;

// P6 = txb < n (column bound for this thread's B load)
-:-:-:-:00 ISETP.LT.AND P6, PT, txb, param_n, PT;

// writeS = ((128*tidY + tidX) * 4) ^ 4x<128*8*2>
// The same XOR constant is applied to writeS, readAs and readBs again after
// the first tile is stored and inside the loop inserts; presumably it toggles
// between two shared-memory buffers -- confirm against the included template.
-:-:-:-:00 ISCADD writeS, tidY, tidX, 7;
-:-:-:-:00 SHL writeS, writeS, 2;
-:-:-:-:00 LOP.XOR writeS, writeS, 4x<128*8*2>;

// readAs  = (((tid & 0x70) >> 3) | (tid & 1)) << 4
-:-:-:-:00 LOP.AND tid1, tid, 1;
-:-:-:-:00 LOP.AND readAs, tid, 0x70;
-:-:-:-:00 SHR.U32 readAs, readAs, 3;
-:-:-:-:00 LOP.OR readAs, readAs, tid1;
-:-:-:-:00 SHL readAs, readAs, 4;


// readBs = (((tid128 >> 4) | ((tid >> 1) & 7)) << 4) + 4096
// where tid128 = tid & 128 and 4096 is 4x<128*8>.
-:-:-:-:00 LOP.AND tid128, tid, 128;
-:-:-:-:00 BFE.U32 tid7, tid, 0x301; // 3 bits at position 1
-:-:-:-:00 SHR.U32 readBs, tid128, 4;
-:-:-:-:00 LOP.OR readBs, readBs, tid7;
-:-:-:-:00 ISCADD readBs, readBs, 4x<128*8>, 4;

// Compute this thread's slot in the Rand buffer.
// seed = (blkB*gridDimA*256 + blkA*256 + tid) & (2048*32 - 1)
// Rand = param_Rand + seed * 4
// The flag test and the actual seed load are commented out below, so after
// this block `seed` holds the masked index, not a value read from memory.
-:-:-:-:00 MOV flag, param_flags;
//-:-:-:-:00 LOP.AND.NZ P1, RZ, flag, 0x1;
//set P1 as false
//-:-:-:-:00 ISETP.LT.AND P1, PT, RZ, RZ, !PT;
-:-:-:-:00 MOV dimA, gridDimA;
// tbid = blkA*256 + tid
-:-:-:-:00 ISCADD tbid, blkA, tid, 8;
//-:-:-:-:00 XMAD.U16.U16 dimA, blkB, dimA, RZ;
// dimA = blkB * gridDimA (IMAD replaces the commented-out XMAD)
-:-:-:-:00 IMAD dimA, blkB, dimA, RZ;
// tbid += dimA * 256
-:-:-:-:00 ISCADD tbid, dimA, tbid, 8;
-:-:-:-:00 LOP.AND seed, tbid, 1x<2048*32 - 1>;
//-:-:-:-:00 LEA Rand0.CC, seed, param_Rand[0], 2;
//-:-:-:-:00 LEA.HI.X Rand1, seed, param_Rand[1], RZ, 2;
// Hand expansion of the commented-out LEA pair. param_A_base_lo/hi and ta_shl
// are reused here as scratch; their trackA values were consumed above.
-:-:-:-:00 MOV param_A_base_lo, param_Rand[0];
-:-:-:-:00 MOV param_A_base_hi, param_Rand[1];
-:-:-:-:00 SHL ta_shl, seed, 2;
-:-:-:-:00 IADD Rand0.CC, ta_shl, param_A_base_lo;
-:-:-:-:00 IADD.X Rand1, RZ, param_A_base_hi;

//-:-:-:-:00 @P1 LD.E.CS seed, [Rand];
//</SCHEDULE_BLOCK>

REMAINDER:

//<SCHEDULE_BLOCK>

// Tile load. Reached by fall-through from the prologue and again through the
// "@P1 BRA.U REMAINDER" emitted by the loop inserts further down.
// The CODE section returns one of two instruction texts depending on $vec.
// $vec is forced to 1 here, so only the first (128-bit) text is emitted:
//   P4 = (k & 7) == 0
//   P1 = k > 8 && !P4
//   P2 = tidY < k && P5      (P5 is txa < param_m)
//   P3 = tidY < k && P6      (P6 is txb < param_n)
//   loadA/loadB come from LDG.E.128 when P2/P3 hold, otherwise they are
//   filled by LDS.U.128 from addr_zero.
// The second text (scalar path) checks each of the four elements against
// param_m / param_n individually and sets P1 = k > 8.
<CODE>
    our $vec;
    $vec = 1;
    return $vec ? q{
// bDoRemainder = k & 7 && k > 8
//-:-:-:-:00 LOP.AND.NZ P4, RZ, k, 7;
-:-:-:-:00 LOP.AND k_and, k, 7;
-:-:-:-:00 ISETP.EQ.AND P4, PT, k_and, RZ, PT; 
-:-:-:-:00 ISETP.GT.AND P1, PT, k, 8, !P4;

// doLoad = tidY < k && txa|txb < n|m
-:-:-:-:00 ISETP.LT.AND P2, PT, tidY, k, P5;
-:-:-:-:00 ISETP.LT.AND P3, PT, tidY, k, P6;

-:-:-:-:00 @P2 LDG.E.128 loadA, [trackA];
-:-:-:-:00 @P3 LDG.E.128 loadB, [trackB];

T:-:D:S:00 @!P2 LDS.U.128 loadA, [RZ + addr_zero];
T:-:D:S:00 @!P3 LDS.U.128 loadB, [RZ + addr_zero];
T:-:D:S:00 TEXDEPBAR 0x0;
    // Vec 4 and scalar loads
    } : q{

// doLoadA = tidY < k && txa < m
// doLoadB = tidY < k && txb < n
-:-:-:-:00 IADD x1, txa, 1;
-:-:-:-:00 IADD x2, txa, 2;
-:-:-:-:00 IADD x3, txa, 3;
-:-:-:-:00 ISETP.LT.AND P0, PT, tidY, k, P5;
-:-:-:-:00 ISETP.LT.AND P1, PT, x1, param_m, P0;
-:-:-:-:00 ISETP.LT.AND P2, PT, x2, param_m, P0;
-:-:-:-:00 ISETP.LT.AND P3, PT, x3, param_m, P0;

-:-:-:-:00 @P0 LD.E.CI loadA0, [trackA + 4x<0>];
-:-:-:-:00 @P1 LD.E.CI loadA1, [trackA + 4x<1>];
-:-:-:-:00 @P2 LD.E.CI loadA2, [trackA + 4x<2>];
-:-:-:-:00 @P3 LD.E.CI loadA3, [trackA + 4x<3>];

-:-:-:-:00 @!P0 MOV loadA0, RZ;
-:-:-:-:00 @!P1 MOV loadA1, RZ;
-:-:-:-:00 @!P2 MOV loadA2, RZ;
-:-:-:-:00 @!P3 MOV loadA3, RZ;

-:-:-:-:00 IADD y1, txb, 1;
-:-:-:-:00 IADD y2, txb, 2;
-:-:-:-:00 IADD y3, txb, 3;
-:-:-:-:00 ISETP.LT.AND P0, PT, tidY, k, P6;
-:-:-:-:00 ISETP.LT.AND P1, PT, y1, param_n, P0;
-:-:-:-:00 ISETP.LT.AND P2, PT, y2, param_n, P0;
-:-:-:-:00 ISETP.LT.AND P3, PT, y3, param_n, P0;

-:-:-:-:00 @P0 LD.E.CI loadB0, [trackB + 4x<0>];
-:-:-:-:00 @P1 LD.E.CI loadB1, [trackB + 4x<1>];
-:-:-:-:00 @P2 LD.E.CI loadB2, [trackB + 4x<2>];
-:-:-:-:00 @P3 LD.E.CI loadB3, [trackB + 4x<3>];

-:-:-:-:00 @!P0 MOV loadB0, RZ;
-:-:-:-:00 @!P1 MOV loadB1, RZ;
-:-:-:-:00 @!P2 MOV loadB2, RZ;
-:-:-:-:00 @!P3 MOV loadB3, RZ;

-:-:-:-:00 ISETP.GT.AND P1, PT, k, 8, PT;
    };
</CODE>

//</SCHEDULE_BLOCK>

// Store the loaded A tile at writeS and advance trackA by param_lda8 (64-bit
// add: low word with .CC, high word with carry).
-:-:-:-:00 STS.128 [writeS + 4x<0*128>], loadA0;

-:-:-:-:00 IADD trackA0.CC, trackA0, param_lda8;
-:-:-:-:00 IADD.X trackA1, trackA1, RZ;

// Store the loaded B tile 4x<8*128> past writeS and start advancing trackB.
-:-:-:-:00 STS.128 [writeS + 4x<8*128>], loadB0;

-:-:-:-:00 IADD trackB0.CC, trackB0, param_ldb8;

// Flip the read pointers with the same XOR constant used for writeS, wait for
// every thread's stores at the barrier, then flip writeS as well.
-:-:-:-:00 LOP.XOR readAs, readAs, 4x<128*8*2>;
-:-:-:-:00 LOP.XOR readBs, readBs, 4x<128*8*2>;
-:-:-:-:00 BAR.SYNC 0;
-:-:-:-:00 LOP.XOR writeS, writeS, 4x<128*8*2>;

// High word of the trackB advance started above.
// NOTE(review): the carry is consumed only after the XORs and BAR.SYNC; this
// assumes none of them modify the carry flag -- keep the order if editing.
-:-:-:-:00 IADD.X trackB1, trackB1, RZ;


<CODE>
    # Instructions to splice into the unrolled loop body (presumably consumed
    # by the sgemm_common_128x128.sass template included right after this
    # section). Keys have the form j<N>c<M>; each value is the instruction
    # text for that slot and may hold more than one line (see j7c63).
    #
    # %insert is filled by list assignment, so when a key is listed twice only
    # the LAST value survives. Slot j0c62 is wanted both for the P3 compare
    # and for the 128-bit load of loadA; listing them as two entries made the
    # load replace the compare, so P3 was never refreshed inside the loop and
    # the "@P3" load/advance of B kept running on the final iteration. Both
    # instructions are now emitted from the single surviving entry.
    our $vec;
    $vec = 1;    # force the 128-bit (vector) load path
    my $k_end = $vec ? 16 : 24;

    # P2/P3: another full tile of A/B remains and this thread is in bounds
    # (P5/P6 are the txa < m / txb < n predicates from the prologue).
    my $set_p2 = "-:-:-:-:00 ISETP.GE.AND P2, PT, k, $k_end, P5;\n";
    my $set_p3 = "-:-:-:-:00 ISETP.GE.AND P3, PT, k, $k_end, P6;\n";

    #our @top = ($set_p2);
    our @top = ("");
    our %insert =
    (
        j0c29 => $set_p2,
        j0c62 => $set_p3,
        j1c62 => "-:-:-:-:00 ISETP.GE.AND P0, PT, k, $k_end, PT;\n",

        ($vec ?
            (
                # Shares slot j0c62 with the P3 compare: keep both.
                j0c62 => $set_p3 .
                         "-:G:D:-:00 \@P2 LDG.E.128 loadA, [trackA];\n",
                j1c53 => "-:G:D:-:00 \@P3 LDG.E.128 loadB, [trackB];\n",
            ) :
            (
                # NOTE(review): unreachable while $vec is forced to 1 above.
                # j0c29 collided with the P2 compare in the same way as j0c62
                # did, so the compare is kept in front of the load here too.
                # Assuming slots run in ascending order, j0c10 still reads P2
                # (and j0c35 reads P3) before any slot in this table sets it;
                # rework this path before re-enabling it.
                j0c10 => "-:-:-:-:00 \@P2 LD.E.CI loadA0, [trackA + 4x<0>];\n",
                j0c29 => $set_p2 .
                         "-:-:-:-:00 \@P2 LD.E.CI loadA1, [trackA + 4x<1>];\n",
                j0c31 => "-:-:-:-:00 \@P2 LD.E.CI loadA2, [trackA + 4x<2>];\n",
                j0c33 => "-:-:-:-:00 \@P2 LD.E.CI loadA3, [trackA + 4x<3>];\n",

                j0c35 => "-:-:-:-:00 \@P3 LD.E.CI loadB0, [trackB + 4x<0>];\n",
                j1c29 => "-:-:-:-:00 \@P3 LD.E.CI loadB1, [trackB + 4x<1>];\n",
                j1c31 => "-:-:-:-:00 \@P3 LD.E.CI loadB2, [trackB + 4x<2>];\n",
                j1c33 => "-:-:-:-:00 \@P3 LD.E.CI loadB3, [trackB + 4x<3>];\n",
            )
        ),

        # Remaining depth for the next iteration.
        j5c29 => "-:-:-:-:00 IADD32I k, k, -8;\n",
        #j5c63 => "-:-:-:-:00 TEXDEPBAR 0x0; \n",
        j5c63 => "T:-:D:S:00 TEXDEPBAR 0x0; \n",

        # Advance the global pointers only when the matching load was issued.
        j2c53 => "-:-:-:-:00 \@P2 IADD trackA0.CC, trackA0, param_lda8;\n",
        j3c53 => "-:-:-:-:00 \@P2 IADD.X trackA1, trackA1, RZ;\n",
        j3c61 => "-:-:-:-:00 \@P3 IADD trackB0.CC, trackB0, param_ldb8;\n",
        j4c61 => "-:-:-:-:00 \@P3 IADD.X trackB1, trackB1, RZ;\n",

        # Store the next tile and flip the shared-memory pointers (P0 only).
        j6c29 => "-:-:-:S:00 \@P0 STS.128 [writeS + 4x<0*128>], loadA0;\n",
        j6c53 => "-:-:-:S:00 \@P0 STS.128 [writeS + 4x<8*128>], loadB0;\n",
        j6c61 => "-:-:-:-:00 \@P0 LOP.XOR readAs, readAs, 4x<128*8*2>;\n",
        j6c62 => "-:-:-:-:00 \@P0 LOP.XOR readBs, readBs, 4x<128*8*2>;\n",
        j6c63 => "-:-:D:-:00 \@P0 BAR.SYNC 0;\n",
        j7c62 => "-:-:-:-:00 \@P0 LOP.XOR writeS, writeS, 4x<128*8*2>;\n",

        # Loop again while P0 holds; otherwise take the REMAINDER path if P1.
        j7c63 => "-:-:D:-:00 \@P0 BRA.U LOOP;\n" .
                 "-:-:D:-:00 \@P1 BRA.U REMAINDER;\n",
    );
    return;
</CODE>

<INCLUDE file="sgemm_common_128x128.sass">
